import jax.numpy as np
import importlib
import jax.experimental.optimizers as jaxoptimizers
import scalevi.utils.utils as utils

@jaxoptimizers.optimizer
def custom_momentum_star(step_size: jaxoptimizers.Schedule, mass: float=0.9):
  """Build the optimizer triple for sign-of-momentum SGD.

  A velocity accumulator is updated as ``mass * velocity + g``; the parameters
  then move by ``step_size(i)`` against the elementwise sign of that velocity,
  so the velocity's magnitude never scales the step.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.
    mass: momentum coefficient multiplying the previous velocity.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  schedule = jaxoptimizers.make_schedule(step_size)

  def init(x0):
    # State is (parameters, velocity); the velocity starts at zero.
    return x0, np.zeros_like(x0)

  def update(i, g, state):
    params, prev_velocity = state
    new_velocity = mass * prev_velocity + g
    # Only the direction (sign) of the accumulated velocity is used.
    direction = np.sign(new_velocity)
    return params - schedule(i) * direction, new_velocity

  def get_params(state):
    params, _velocity = state
    return params

  return init, update, get_params

@jaxoptimizers.optimizer
def custom_momentum(step_size: jaxoptimizers.Schedule, mass: float=0.9):
  """Construct optimizer triple for SGD with (heavy-ball) momentum.

  The velocity is accumulated as ``mass * velocity + g`` and the parameters
  are moved by ``step_size(i) * velocity``; the full velocity, not just its
  sign, scales the step.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.
    mass: positive scalar representing the momentum coefficient.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = jaxoptimizers.make_schedule(step_size)

  def init(x0):
    # State is (parameters, velocity); the velocity starts at zero.
    v0 = np.zeros_like(x0)
    return x0, v0

  def update(i, g, state):
    x, velocity = state
    velocity = mass * velocity + g
    # Step along the accumulated velocity itself (magnitude included).
    x = x - step_size(i) * velocity
    return x, velocity

  def get_params(state):
    x, _ = state
    return x

  return init, update, get_params

@jaxoptimizers.optimizer
def custom_amsgrad(step_size, b1=0.9, b2=0.999, eps=1e-8):
  """Build the optimizer triple for AMSGrad.

  Exponential moving averages of the gradient and of its square are kept,
  together with the running elementwise maximum of the squared-gradient
  average. The step divides the first moment by the square root of that
  maximum. No bias correction is applied to either moment.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.
    b1: optional, exponential decay rate for the first moment estimate
      (default 0.9).
    b2: optional, exponential decay rate for the second moment estimate
      (default 0.999).
    eps: optional, small constant added to the square root in the denominator
      for numerical stability (default 1e-8).

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  schedule = jaxoptimizers.make_schedule(step_size)

  def init(x0):
    # State is (params, first moment, second moment, max second moment).
    return x0, np.zeros_like(x0), np.zeros_like(x0), np.zeros_like(x0)

  def update(i, g, state):
    params, first_moment, second_moment, second_moment_max = state
    first_moment = (1 - b1) * g + b1 * first_moment
    second_moment = (1 - b2) * np.square(g) + b2 * second_moment
    # The denominator can never shrink: keep the largest second moment seen.
    second_moment_max = np.maximum(second_moment_max, second_moment)
    denominator = np.sqrt(second_moment_max) + eps
    params = params - schedule(i) * first_moment / denominator
    return params, first_moment, second_moment, second_moment_max

  def get_params(state):
    params, _m, _v, _vmax = state
    return params

  return init, update, get_params


@jaxoptimizers.optimizer
def custom_adabelief(step_size, b1=0.9, b2=0.999, eps=1e-8, eps2=1e-8):
  """Build the optimizer triple for AdaBelief.

  The second-moment accumulator tracks the squared deviation of the gradient
  from its moving average (``g - m``) rather than the squared gradient. Both
  moments are bias-corrected before the step is taken.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.
    b1: optional, exponential decay rate for the first moment estimate
      (default 0.9).
    b2: optional, exponential decay rate for the second moment estimate
      (default 0.999).
    eps: optional, small constant added to the second moment before bias
      correction and again to the square root in the denominator
      (default 1e-8).
    eps2: optional, small constant added to the second-moment accumulator on
      every update; it is part of the stored state and so carries over to
      later iterations (default 1e-8).

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  schedule = jaxoptimizers.make_schedule(step_size)

  def init(x0):
    # State is (params, first moment, second moment), moments start at zero.
    return x0, np.zeros_like(x0), np.zeros_like(x0)

  def update(i, g, state):
    params, first_moment, second_moment = state
    first_moment = (1 - b1) * g + b1 * first_moment
    # Deviation of the gradient from the freshly updated first moment.
    residual = g - first_moment
    second_moment = (1 - b2) * np.square(residual) + b2 * second_moment + eps2
    # Bias correction uses a 1-based step count.
    step_count = i + 1
    first_correction = 1 - np.asarray(b1, first_moment.dtype) ** step_count
    second_correction = 1 - np.asarray(b2, first_moment.dtype) ** step_count
    first_hat = first_moment / first_correction
    second_hat = (second_moment + eps) / second_correction
    params = params - schedule(i) * first_hat / (np.sqrt(second_hat) + eps)
    return params, first_moment, second_moment

  def get_params(state):
    params, _m, _v = state
    return params

  return init, update, get_params


